#   LibAut
#       joint onset & outcome regressions, leave-one-out CV on OUTCOME
#       LOO-CV CC for Table 1 
#   Juraj Medzihorsky
#   2021-08-12

library(GJRM)
library(parallel)
options(mc.cores=14) # written for a 16-thread linux machine
options(stringsAsFactors=F)
source('helpers.R')

#   sessionInfo()
##  R version 4.0.4 (2021-02-15)
##  Platform: x86_64-pc-linux-gnu (64-bit)
##  Running under: Ubuntu 20.04.2 LTS
##
##  Matrix products: default
##  BLAS/LAPACK: /opt/OpenBLAS/lib/libopenblas_haswellp-r0.3.13.so
##
##  locale:
##   [1] LC_CTYPE=C.UTF-8           LC_NUMERIC=C
##   [3] LC_TIME=en_US.UTF-8        LC_COLLATE=C.UTF-8
##   [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=C.UTF-8
##   [7] LC_PAPER=C.UTF-8           LC_NAME=C
##   [9] LC_ADDRESS=C               LC_TELEPHONE=C
##  [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
##  attached base packages:
##  [1] parallel  stats     graphics  grDevices utils     datasets  methods
##  [8] base
##
##  other attached packages:
##  [1] GJRM_0.2-4     mgcv_1.8-34    nlme_3.1-152   nvimcom_0.9-92
##
##  loaded via a namespace (and not attached):
##   [1] ismev_1.42          tidyselect_1.1.0    purrr_0.3.4
##   [4] mitools_2.4         splines_4.0.4       lattice_0.20-41
##   [7] pcaPP_1.9-73        colorspace_2.0-0    vctrs_0.3.6
##  [10] generics_0.1.0      stats4_4.0.4        gmp_0.6-2
##  [13] survival_3.2-10     rlang_0.4.10        pillar_1.4.7
##  [16] Rmpfr_0.8-4         glue_1.4.2          DBI_1.1.1
##  [19] distrEx_2.8.0       VineCopula_2.4.1    matrixStats_0.58.0
##  [22] trust_0.1-8         lifecycle_1.0.0     munsell_0.5.0
##  [25] gtable_0.3.0        pspline_1.0-18      mvtnorm_1.1-1
##  [28] VGAM_1.1-5          psych_2.1.3         magic_1.5-9
##  [31] gamlss.dist_5.3-2   ADGofTest_0.3       copula_1.0-1
##  [34] trustOptim_0.8.6.2  evd_2.3-3           Rcpp_1.0.6
##  [37] scales_1.1.1        startupmsg_0.9.6    scam_1.2-11
##  [40] tmvnsim_1.0-2       abind_1.4-5         mnormt_2.0.2
##  [43] ggplot2_3.3.3       dplyr_1.0.4         distr_2.8.0
##  [46] gsl_2.1-6           survey_4.0          numDeriv_2016.8-1.1
##  [49] grid_4.0.4          stabledist_0.7-1    tools_4.0.4
##  [52] magrittr_2.0.1      tibble_3.0.6        crayon_1.4.1
##  [55] pkgconfig_2.0.3     ellipsis_0.3.1      MASS_7.3-53.1
##  [58] Matrix_1.3-2        matrixcalc_1.0-3    assertthat_0.2.1
##  [61] R6_2.5.0            sfsmisc_1.1-11      compiler_4.0.4

#------------------------------------------------------------------------------

#   Locate the data directory: the current working directory with a trailing
#   'code' replaced by 'data' (the path is unchanged if it does not end in
#   'code').
code_dir <- getwd()
data_dir <- gsub('code$', 'data', code_dir)

#   Read the inputs through explicit paths rather than switching the working
#   directory back and forth, so that a failed read cannot leave the session
#   sitting in data_dir.
dat <- readRDS(file.path(data_dir, 'onset_20210810.rds'))
eps <- read.csv(file.path(data_dir, 'eps.csv'))

#   keep only rows with onset_eligible equal to 1 (%in% never yields NA, so
#   rows with a missing flag are dropped)
dat <- dat[dat$onset_eligible %in% 1, ]

#------------------------------------------------------------------------------
#   get EP shares
#   Yearly mean of one column of `eps`, returned as a two-column data frame
#   (year, <label>). Missing values are ignored within each year.
ep_share_by_year <-
    function(x, label)
    {
        out <- aggregate(x, by=list(eps$year), FUN=mean, na.rm=TRUE)
        names(out) <- c('year', label)
        out
    }

tab_dem <- ep_share_by_year(eps$dem_ep, 'share_dem_ep')
tab_aut <- ep_share_by_year(eps$aut_ep, 'share_aut_ep')

tab <- merge(tab_dem, tab_aut, by='year')

#   lagged copy: the shares computed for year t are attached to year t + 1
tab_lag <- tab
tab_lag$year <- tab$year + 1
names(tab_lag)[2:3] <- paste0('lag1_', names(tab)[2:3])

#   'year' is the only column the two tables share, so it is the join key
tab_tab <- merge(tab, tab_lag)

#   merge() keeps matching keys only (all=FALSE): rows of dat whose year has
#   no entry in tab_tab are dropped here
dat <- merge(dat, tab_tab, by='year')

#------------------------------------------------------------------------------
#   get outcome

#   recode outcomes
#   epy: suc 1 fail 0
#   Binary episode outcome `epy`, recoded from dem_ep_outcome:
#       1, 5     -> 1   (success)
#       2, 3, 4  -> 0   (fail)
#       6        -> NA  (censored)
#   Any other code, or a missing code, also stays NA.
dat$epy <- local({
    recoded <- rep(NA_real_, nrow(dat))
    recoded[dat$dem_ep_outcome%in%c(1,5)] <- 1
    recoded[dat$dem_ep_outcome%in%c(2,3,4)] <- 0
    recoded
})

#------------------------------------------------------------------------------

#   natural log of e_mipopula, and e_regionpol_6C as a factor
dat$logpop <- log(dat$e_mipopula)
dat$reg_fac <- as.factor(dat$e_regionpol_6C)

#------------------------------------------------------------------------------

#   filter out missing
#   Variables used by the models. 'epy' is deliberately first: it is skipped
#   in the completeness check below, so rows with a missing outcome are kept.
vars_to_go <- c(
    'epy',
    'libaut_start',
    'year',
    'reg_fac',
    'lag1_e_migdppcln',
    'lag1_e_migdpgro',
    'lag1_logpop',
    'lag1_v2pepwrsoc',
    'lag1_v2xeg_eqdr',
    'lag1_v2xnp_pres',
    'lag1_v2csprtcpt',
    'lag1_v2x_polyarchy',
    'lag1_excl_region_edi',
    'lag1_e_miinteco',
    'lag1_e_miinterc',
    'lag1_share_dem_ep',
    'lag1_share_aut_ep'
)

#   keep the rows that are complete on every listed variable except epy
dat <- dat[complete.cases(dat[, vars_to_go[-1]]), ]
nrow(dat)

#------------------------------------------------------------------------------

#   formulas
#   Intercept-only starting points: y_* is the outcome equation (epy),
#   s_* is the onset equation (libaut_start).
y_0 <- epy ~ 1
s_0 <- libaut_start ~ 1
#   covariates shared by both equations
lin_1 <- ~ . + 
    reg_fac +
    lag1_e_migdppcln + 
    lag1_e_migdpgro + 
    lag1_logpop +
    lag1_v2pepwrsoc + 
    lag1_v2xeg_eqdr + 
    lag1_v2xnp_pres + 
    lag1_v2csprtcpt +
    lag1_v2x_polyarchy +
    lag1_excl_region_edi + 
    lag1_e_miinteco + 
    lag1_e_miinterc 
#   cubic polynomial in year, centred at 1990 and scaled to decades
lin_year <- ~ . + I((1e-1*(year-1990))) + I((1e-1*(year-1990))^2) + I((1e-1*(year-1990))^3) 
#   terms that enter the onset equation only
lin_sel <- ~ . + lag1_share_dem_ep + lag1_share_aut_ep
#   outcome equation: covariates, then covariates + year polynomial
y_l_1 <- update(y_0, lin_1)
y_l_2 <- update(y_l_1, lin_year)
#   onset equation: covariates + EP shares, then the same + year polynomial.
#   s_l_1 already carries the lin_sel terms, so they are not added a second
#   time for s_l_2 (update() would drop the duplicates anyway).
s_l_1 <- update(update(s_0, lin_1), lin_sel)
s_l_2 <- update(s_l_1, lin_year)
#   formula lists in the order (onset equation, outcome equation)
fl1 <- list(s_l_1, y_l_1)
fl2 <- list(s_l_2, y_l_2)

#   bernoulli-logistic (x2) normal-copula GJRM for one left-out observation
#   OUTCOME fit
#   Refit the joint model without the rows `ii` and score the held-out rows.
#
#   ii      row index (or indices) of `dat` to hold out; a list is flattened
#   dat     data frame holding every variable named in `form`, plus `epy`
#   form    list of two formulas passed on to gjrm(); the second is the
#           outcome equation
#
#   Returns a one-row data frame whose column `cc` is the share of held-out
#   rows for which the predicted probability, rounded to 0/1, equals `epy`.
one <-
    function(ii, dat, form)
    {
        held_out <- unlist(ii)
        fit <- gjrm(form, data=dat[-held_out,],
                    Model='BSS', gamlssfit=TRUE, extra.regI='t', 
                    margin=c('logit', 'logit'), BivD='N')
        #   NOTE(review): predictions are taken from the `gam2` component of
        #   the fitted object. TODO confirm against the GJRM documentation
        #   that `gam2` carries the jointly estimated coefficients and not
        #   only the univariate fit for the second equation.
        outcome_fit <- fit$gam2
        p_hat <- predict(outcome_fit, newdata=dat[held_out,], type='response')
        #   round() turns the probability into a 0/1 class at the 0.5 cut
        data.frame(cc=mean(round(p_hat)==dat[held_out,'epy']))
    }

#------------------------------------------------------------------------------
#   execute

#   Rows with an observed 0/1 outcome; each of them is held out exactly once.
rr <- which(dat$epy%in%0:1)

#   One full leave-one-out pass for a formula list: calls one() per held-out
#   row and stacks the one-row results into a single data frame.
#   NOTE(review): `parallel` is loaded and mc.cores is set at the top of the
#   script, yet this pass runs serially through lapply -- confirm intended.
loo_pass <- function(form) do.call(rbind, lapply(rr, one, dat=dat, form=form))

system.time(heck_1 <- loo_pass(fl1))
system.time(heck_2 <- loo_pass(fl2))

#------------------------------------------------------------------------------
#   save   

#   Write both LOO result tables into data_dir through an explicit path. The
#   working directory is left untouched, so the script can be re-run in the
#   same session without code_dir resolving to the data directory.
save(list=c('heck_1', 'heck_2'), 
     file=file.path(data_dir, 'joint_loo_outcome_20210812.rdata'))


#   round(rbind(
#         cbind(mean(heck_1$cc),  mean(heck_2$cc))),
#         2)
#   # second column is in Table 1, column 3
#   # 0.81 0.79

#   SCRIPT END